package org.gavin.sharding.utils;

import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.driver.jdbc.core.datasource.ShardingSphereDataSource;
import org.apache.shardingsphere.infra.config.rule.RuleConfiguration;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.sharding.api.config.ShardingRuleConfiguration;
import org.apache.shardingsphere.sharding.api.config.rule.ShardingTableRuleConfiguration;
import org.apache.shardingsphere.spring.boot.datasource.DataSourceMapSetter;
import org.springframework.util.ReflectionUtils;

import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.YearMonth;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.regex.Pattern;

/**
 * ShardingAlgorithmTool
 *
 * @author grl
 * @date 2024/2/4
 */
@Slf4j
public class ShardingAlgorithmTool {

    /** 表分片符号，例：t_user_202201 中，分片符号为 "_" */
    private static final String TABLE_SPLIT_SYMBOL = "_";
    /**
     * 获取所有表名
     * @return 表名集合
     * @param logicTableName 逻辑表
     */
    public static List<String> getAllTableNameBySchema(String logicTableName) {
        List<String> tableNames = new ArrayList<>();
        Map<String, DataSource> dataSourceMap = DataSourceMapSetter.getDataSourceMap(SpringUtil.getEnvironment());
        if(!dataSourceMap.isEmpty()){
            dataSourceMap.forEach((key,dataSource)->{
                try (Connection conn = dataSource.getConnection();
                     Statement st = conn.createStatement()) {
                    try (ResultSet rs = st.executeQuery("show TABLES like '" + logicTableName + TABLE_SPLIT_SYMBOL + "%'")) {
                        while (rs.next()) {
                            String tableName = rs.getString(1);
                            // 匹配分表格式 例：^(t\_contract_\d{6})$
                            if (tableName != null && tableName.matches(String.format("^(%s\\d{6})$", logicTableName + TABLE_SPLIT_SYMBOL))) {
                                tableNames.add(rs.getString(1));
                            }
                        }
                    }
                } catch (SQLException e) {
                    log.error(">>>>>>>>>> 【ERROR】sharding数据源链接失败!数据源名称 {}",key,e);
                }
            });
            return tableNames;
        }
        log.error(">>>>>>>>>> 【ERROR】sharding数据源不能为空");
        throw new IllegalArgumentException("数据库连接配置有误，请稍后重试");
    }

    /**
     * 刷新ActualDataNodes
     */
    private static void updateShardRuleActualDataNodes(ShardingSphereDataSource dataSource, String logicTableName, String newActualDataNodes) throws NoSuchFieldException {
        // Context manager.
        Field contextManagerField = dataSource.getClass().getDeclaredField("contextManager");
        ContextManager contextManager = (ContextManager) ReflectionUtils.getField(contextManagerField, dataSource);

        // Rule configuration.
        ShardingSphereMetaData shardingSphereMetaData = contextManager
                .getMetaDataContexts()
                .getMetaData();
        Collection<RuleConfiguration> newRuleConfigList = new LinkedList<>();
        Collection<RuleConfiguration> oldRuleConfigList = shardingSphereMetaData
                .getGlobalRuleMetaData()
                .getConfigurations();

        for (RuleConfiguration oldRuleConfig : oldRuleConfigList) {
            if (oldRuleConfig instanceof ShardingRuleConfiguration) {

                // Algorithm provided sharding rule configuration
                ShardingRuleConfiguration oldAlgorithmConfig = (ShardingRuleConfiguration) oldRuleConfig;
                ShardingRuleConfiguration newAlgorithmConfig = new ShardingRuleConfiguration();

                // Sharding table rule configuration Collection
                Collection<ShardingTableRuleConfiguration> newTableRuleConfigList = new LinkedList<>();
                Collection<ShardingTableRuleConfiguration> oldTableRuleConfigList = oldAlgorithmConfig.getTables();

                oldTableRuleConfigList.forEach(oldTableRuleConfig -> {
                    if (logicTableName.equals(oldTableRuleConfig.getLogicTable())) {
                        ShardingTableRuleConfiguration newTableRuleConfig = new ShardingTableRuleConfiguration(oldTableRuleConfig.getLogicTable(), newActualDataNodes);
                        newTableRuleConfig.setTableShardingStrategy(oldTableRuleConfig.getTableShardingStrategy());
                        newTableRuleConfig.setDatabaseShardingStrategy(oldTableRuleConfig.getDatabaseShardingStrategy());
                        newTableRuleConfig.setKeyGenerateStrategy(oldTableRuleConfig.getKeyGenerateStrategy());

                        newTableRuleConfigList.add(newTableRuleConfig);
                    } else {
                        newTableRuleConfigList.add(oldTableRuleConfig);
                    }
                });

                newAlgorithmConfig.setTables(newTableRuleConfigList);
                newAlgorithmConfig.setAutoTables(oldAlgorithmConfig.getAutoTables());
                newAlgorithmConfig.setBindingTableGroups(oldAlgorithmConfig.getBindingTableGroups());
                newAlgorithmConfig.setBroadcastTables(oldAlgorithmConfig.getBroadcastTables());
                newAlgorithmConfig.setDefaultDatabaseShardingStrategy(oldAlgorithmConfig.getDefaultDatabaseShardingStrategy());
                newAlgorithmConfig.setDefaultTableShardingStrategy(oldAlgorithmConfig.getDefaultTableShardingStrategy());
                newAlgorithmConfig.setDefaultKeyGenerateStrategy(oldAlgorithmConfig.getDefaultKeyGenerateStrategy());
                newAlgorithmConfig.setDefaultShardingColumn(oldAlgorithmConfig.getDefaultShardingColumn());
                newAlgorithmConfig.setShardingAlgorithms(oldAlgorithmConfig.getShardingAlgorithms());
                newAlgorithmConfig.setKeyGenerators(oldAlgorithmConfig.getKeyGenerators());

                newRuleConfigList.add(newAlgorithmConfig);
            }
        }

        // update context
        String schemaName = "logic_db";
        contextManager.alterRuleConfiguration(schemaName, newRuleConfigList);
    }

    /**
     * 创建分表2
     * @param logicTableName  逻辑表
     * @param resultTableName 真实表名，例：t_user_202201
     * @return 创建结果（true创建成功，false未创建）
     */
    public static boolean createShardingTable(String logicTableName, String resultTableName) {
        // 根据日期判断，当前月份之后分表不提前创建
        String month = resultTableName.replace(logicTableName + TABLE_SPLIT_SYMBOL,"");
        YearMonth shardingMonth = YearMonth.parse(month, DateTimeFormatter.ofPattern("yyyyMM"));
        if (shardingMonth.isAfter(YearMonth.now())) {
            return false;
        }

        synchronized (logicTableName.intern()) {
            // 缓存中无此表，则建表并添加缓存
            executeSql(Collections.singletonList("CREATE TABLE IF NOT EXISTS `" + resultTableName + "` LIKE `" + logicTableName + "`;"));
        }
        return true;
    }

    /**
     * 执行SQL
     * @param sqlList SQL集合
     */
    private static void executeSql(List<String> sqlList) {
        Map<String, DataSource> dataSourceMap = DataSourceMapSetter.getDataSourceMap(SpringUtil.getEnvironment());
        dataSourceMap.forEach((key,dataSource)->{
            try (Connection conn = dataSource.getConnection()) {
                try (Statement st = conn.createStatement()) {
                    conn.setAutoCommit(false);
                    for (String sql : sqlList) {
                        st.execute(sql);
                    }
                } catch (Exception e) {
                    conn.rollback();
                    log.error(">>>>>>>>>> 【ERROR】数据表创建执行失败，请稍后重试，原因：{}", e.getMessage(), e);
                }
            } catch (SQLException e) {
                log.error(">>>>>>>>>> 【ERROR】数据库连接失败，请稍后重试，原因：{}", e.getMessage(), e);
                throw new IllegalArgumentException("数据库连接失败，请稍后重试");
            }
        });
    }
}
